--- title: Experience Replay keywords: fastai sidebar: home_sidebar summary: "Experience Replay is likely the simplest form of memory used by RL agents. " description: "Experience Replay is likely the simplest form of memory used by RL agents. " nb_path: "nbs/06a_memory.experience_replay.ipynb" ---
Let's generate some batches to test with...
from fastrl.data.gym import *

# Data source: run CartPole-v1 one step at a time (seeded for reproducibility)
# and convert raw steps into first/last transition pairs.
# NOTE: the original cell defined `source` twice back-to-back; the first
# assignment was dead code and has been removed.
source = Source(
    cbs=[GymLoop(env_name='CartPole-v1', steps_delta=1, steps_count=1, seed=0), FirstLast]
)
# Wrap the source in a throwaway learner so we can pull batches of size 5.
learn = fake_gym_learner(source, n=1000, bs=5)
# Materialize each dataloader batch as a BD (batch dict) for indexing below.
batches = [BD(b[0]) for b in learn.dls[0]]
# Replay buffer: at most 20 entries, requires 19 before sampling is allowed.
experience_replay = ExperienceReplay(max_sz=20, warmup_sz=19)
# Nothing has been added yet, so the buffer should be empty.
test_len(experience_replay, 0)
What if we fill up ER? Let's add the batches; this process will happen in place...
# Adding one 5-element batch (in place, via __add__) advances both the
# insertion pointer and the reported length to 5.
experience_replay+batches[0]
test_eq(experience_replay.pointer,5)
test_len(experience_replay,5)
If we add again, the total size should be 10...
# A second batch appends after the first: pointer and length move to 10,
# and the stored steps equal the two batches concatenated in order.
experience_replay+batches[1]
test_eq(experience_replay.pointer,10)
test_len(experience_replay,10)
test_eq(experience_replay.memory['step'],(batches[0]+batches[1])['step'])
# Third batch: buffer grows to 15, contents still in insertion order.
experience_replay+batches[2]
test_len(experience_replay,15)
test_eq(experience_replay.pointer,15)
test_eq(experience_replay.memory['step'],(batches[0]+batches[1]+batches[2])['step'])
# Fourth batch fills the buffer exactly to max_sz=20; the pointer sits at the end.
experience_replay+batches[3]
test_len(experience_replay,20)
test_eq(experience_replay.pointer,20)
test_eq(experience_replay.memory['step'],(batches[0]+batches[1]+batches[2]+batches[3])['step'])
Let's verify that the steps are what we expect...
What if ER is full and we add batches? Since we are at the maximum memory size, we expect that the next batch added should completely overwrite the first 5 entries...
# Buffer is full (pointer == max_sz), so the next add wraps to the start:
# batches[4] overwrites the oldest 5 entries (originally batches[0]) and the
# pointer resets to 5. Length stays capped at 20.
experience_replay+batches[4]
test_len(experience_replay,20)
test_eq(experience_replay.pointer,5)
test_eq(experience_replay.memory['step'],(batches[4]+batches[1]+batches[2]+batches[3])['step'])
This overwrite should properly overwrite the rest of the entries...
# Chained adds (3 batches of 5) overwrite slots 5-19 in sequence, leaving the
# pointer back at the end of the buffer.
experience_replay+batches[5]+batches[6]+batches[7]
test_eq(experience_replay.memory['step'],(batches[4]+batches[5]+batches[6]+batches[7])['step'])
test_eq(experience_replay.pointer,20)
So we have fully overwritten the memory twice, and so far we can prove that the memory overwriting works. Let's see what happens when we keep appending batches...
# Another full-buffer wrap: the 3 new batches replace slots 0-14, so only
# batches[7] survives from the previous cycle; pointer lands at 15.
experience_replay+batches[8]+batches[9]+batches[10]
test_eq(experience_replay.pointer,15)
test_eq(experience_replay.memory['step'],(batches[8]+batches[9]+batches[10]+batches[7])['step'])
What if we need to split a batch to fit at the end and beginning of the memory? This is a possibly scary part where the dictionary needs to be split: some of it needs to be allocated to the end of the memory, and some of it needs to be allocated at the start.
# A 10-element batch added at pointer=15 must be split: the first half
# (batches[11]) fills slots 15-19, the second half (batches[12]) wraps to
# slots 0-4, leaving the pointer at 5.
single_large_batch=batches[11]+batches[12]
experience_replay+single_large_batch;
test_eq(experience_replay.pointer,5)
test_eq(experience_replay.memory['step'],(batches[12]+batches[9]+batches[10]+batches[11])['step'])
What if we sample the experience?
# Expected buffer contents after the final split write above.
full_memory=(batches[12]+batches[9]+batches[10]+batches[11])
# One unique string id per memory entry, built from its (step, episode_id) pair.
entry_ids=list(map(str,torch.hstack((full_memory['step'],full_memory['episode_id']))))
# Track which entries have been returned by sampling at least once.
memory_hits=[False for _ in entry_ids]
We should be able to sample enough times that we have sampled everything. So we test this by sampling, check if that sample has been seen before, and then record that.
# Sample repeatedly; with enough draws every entry should appear at least once.
# (Loop variable renamed to `_` — the index was unused.)
for _ in range(6):
    res, idxs = experience_replay.sample()
    # Mark each sampled (step, episode_id) pair as seen.
    for o in torch.hstack((res['step'], res['episode_id'])):
        memory_hits[entry_ids.index(str(o))] = True
# Every entry in the buffer must have been sampled at least once.
test_eq(all(memory_hits), True)
What happens when we try to update the td_errors?
# Write td_error=1.0 into memory slots 5-9 only.
experience_replay.update_td(TensorBatch(torch.full((5,1),1.0)),torch.arange(5,10))
# Exactly those 5 slots are non-zero, so the whole-buffer sum is 5...
test_eq(experience_replay.memory['td_error'].sum(),5)
# ...and reading back the same slots recovers all of it...
test_eq(experience_replay.memory['td_error'][torch.arange(5,10)].sum(),5)
# ...while a window shifted by one overlaps only 4 of the updated slots.
test_eq(experience_replay.memory['td_error'][torch.arange(6,11)].sum(),4)
from fastrl.data.gym import *
# Image-based variant: render CartPole as rgb_array frames, downscale them by
# 4x via ResReduce, then pair steps with FirstLast as before.
source=Source(cbs=[GymLoop(env_name='CartPole-v1',steps_delta=1,steps_count=1,seed=0,mode='rgb_array'),
ResReduce(reduce_by=4),
FirstLast])
learn=fake_gym_learner(source,n=30,bs=10)
# Callback form of the replay buffer: samples batches of 5, holds up to 20
# entries, and refuses to sample until 11 entries have accumulated.
experience_replay=ExperienceReplayCallback(bs=5,max_sz=20,warmup_sz=11)
# Attach the learner manually since we drive the callback by hand below.
experience_replay.learn=learn
# Drive the callback manually: each after_pred() call either samples from the
# buffer or raises CancelBatchException while still below warmup_sz entries.
for b in learn.dls[0]:
    learn.xb = b
    try:
        experience_replay.after_pred()
    except CancelBatchException:
        print('memory is not full yet!')
    else:
        print('memory sampled')
import matplotlib.pyplot as plt
from IPython.display import HTML
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
# Bare expression: show the callback's repr to inspect the buffer state.
experience_replay
# Two-panel figure: TD-errors on the left, the stored image memory on the right.
fig = make_subplots(rows=1, cols=2, subplot_titles=('Subplot (1,1)', 'Subplot(1,2)'))
# Left panel: td_error per memory slot, shifted by +1 so zero values stay visible.
fig.add_trace(
    go.Scatter(
        mode='lines+markers',
        y=experience_replay.experience_replay.memory['td_error'].numpy().reshape(-1,) + 1,
    ),
    row=1,
    col=1,
)
# Right panel: the raw image observations held in replay memory.
fig.add_trace(
    go.Image(z=experience_replay.experience_replay.memory['image'].numpy()),
    row=1,
    col=2,
)
# A single "Play" button that animates frames 0-9, 500 ms per frame, with no
# transition easing between frames.
updatemenus = [
    dict(
        type='buttons',
        buttons=[
            dict(
                label='Play',
                method='animate',
                args=[
                    [str(k) for k in range(10)],
                    dict(
                        frame=dict(duration=500, redraw=False),
                        transition=dict(duration=0),
                        easing='linear',
                        fromcurrent=True,
                        mode='immediate',
                    ),
                ],
            )
        ],
        direction='left',
        pad=dict(r=10, t=85),
        showactive=True,
        x=0.1,
        y=0,
        xanchor='right',
        yanchor='top',
    )
]
# Frame slider: one step per animation frame (0-9); each step jumps to its
# frame over 500 ms with a linear easing and no inter-frame transition.
_step_opts = {
    'frame': {'duration': 500.0, 'easing': 'linear', 'redraw': False},
    'transition': {'duration': 0, 'easing': 'linear'},
}
sliders = [{
    'yanchor': 'top',
    'xanchor': 'left',
    'currentvalue': {'font': {'size': 16}, 'prefix': 'Frame: ', 'visible': True, 'xanchor': 'right'},
    'transition': {'duration': 500.0, 'easing': 'linear'},
    'pad': {'b': 10, 't': 50},
    'len': 0.9,
    'x': 0.1,
    'y': 0,
    'steps': [
        {'args': [[k], _step_opts], 'label': k, 'method': 'animate'}
        for k in range(10)
    ],
}]
# Attach the play button and frame slider, then render the figure inline.
fig.update_layout(updatemenus=updatemenus,
sliders=sliders);
HTML(fig.to_html())
# Simpler alternative: let plotly express animate over the first axis of the
# stored images (one frame per memory slot), rendered inline.
fig = px.imshow(experience_replay.experience_replay.memory['image'].numpy(),animation_frame=0)
HTML(fig.to_html())